// RUN: spu-opt --optimize-maxpool --split-input-file %s | FileCheck %s

// Positive case: a max reduce_window and a select_and_scatter over the same
// operand %arg0 with the same window (1x2x2x1, stride 1). The pass is expected
// to replace the reduce_window with pphlo.argmax and the select_and_scatter
// with pphlo.maxpool_scatter.
func.func @main(%arg0: tensor<129x24x24x16x!pphlo.secret<f32>>, %arg1: tensor<129x23x23x16x!pphlo.secret<f32>>) -> (tensor<129x23x23x16x!pphlo.secret<f32>>, tensor<129x24x24x16x!pphlo.secret<f32>>) {
    // 0xFF800000 is the f32 bit pattern of -inf, used as the init value of
    // the maximum reduction below.
    %0 = pphlo.constant dense<0xFF800000> : tensor<f32>
    // Zero init value for the add-based scatter.
    %1 = pphlo.constant dense<0.000000e+00> : tensor<f32>
    %2 = pphlo.convert %0 : (tensor<f32>) -> tensor<!pphlo.secret<f32>>
    %3 = pphlo.convert %1 : (tensor<f32>) -> tensor<!pphlo.secret<f32>>
    // CHECK: pphlo.argmax %arg0
    %4 = "pphlo.reduce_window"(%arg0, %2) ({
    ^bb0(%arg2: tensor<!pphlo.secret<f32>>, %arg3: tensor<!pphlo.secret<f32>>):
      %6 = pphlo.maximum %arg2, %arg3 : tensor<!pphlo.secret<f32>>
      pphlo.return %6 : tensor<!pphlo.secret<f32>>
    }) {base_dilations = array<i64: 1, 1, 1, 1>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 2, 2, 1>, window_strides = array<i64: 1, 1, 1, 1>} : (tensor<129x24x24x16x!pphlo.secret<f32>>, tensor<!pphlo.secret<f32>>) -> tensor<129x23x23x16x!pphlo.secret<f32>>
    // Between the argmax and the maxpool_scatter, no select_and_scatter may
    // remain. The positive directive that follows is what bounds this
    // negative one; without it the negative directive would be vacuous.
    // CHECK-NOT: pphlo.select_and_scatter
    // CHECK: pphlo.maxpool_scatter
    // Select region compares with greater_equal; scatter region accumulates
    // with add. Window matches the reduce_window above.
    %5 = "pphlo.select_and_scatter"(%arg0, %arg1, %3) ({
    ^bb0(%arg2: tensor<!pphlo.secret<f32>>, %arg3: tensor<!pphlo.secret<f32>>):
      %6 = pphlo.greater_equal %arg2, %arg3 : (tensor<!pphlo.secret<f32>>, tensor<!pphlo.secret<f32>>) -> tensor<!pphlo.secret<i1>>
      pphlo.return %6 : tensor<!pphlo.secret<i1>>
    }, {
    ^bb0(%arg2: tensor<!pphlo.secret<f32>>, %arg3: tensor<!pphlo.secret<f32>>):
      %6 = pphlo.add %arg2, %arg3 : tensor<!pphlo.secret<f32>>
      pphlo.return %6 : tensor<!pphlo.secret<f32>>
    }) {window_dimensions = array<i64: 1, 2, 2, 1>, window_strides = array<i64: 1, 1, 1, 1>} : (tensor<129x24x24x16x!pphlo.secret<f32>>, tensor<129x23x23x16x!pphlo.secret<f32>>, tensor<!pphlo.secret<f32>>) -> tensor<129x24x24x16x!pphlo.secret<f32>>

    return %4, %5 : tensor<129x23x23x16x!pphlo.secret<f32>>, tensor<129x24x24x16x!pphlo.secret<f32>>
}

// -----

// Negative case: this function contains no reduce_window. %arg0 only feeds an
// element-wise maximum against a zero tensor and the select_and_scatter itself,
// so the pass is expected to leave the select_and_scatter in place.
func.func @main(%arg0: tensor<128x2x2x256x!pphlo.secret<f32>>, %arg1: tensor<128x1x1x256x!pphlo.secret<f32>>) -> (tensor<128x2x2x256x!pphlo.secret<f32>>, tensor<128x2x2x256x!pphlo.secret<f32>>) {
    %zeros = pphlo.constant dense<0.000000e+00> : tensor<128x2x2x256xf32>
    %zero = pphlo.constant dense<0.000000e+00> : tensor<f32>
    %secret_zero = pphlo.convert %zero : (tensor<f32>) -> tensor<!pphlo.secret<f32>>
    %clamped = pphlo.maximum %arg0, %zeros : (tensor<128x2x2x256x!pphlo.secret<f32>>, tensor<128x2x2x256xf32>) -> tensor<128x2x2x256x!pphlo.secret<f32>>
    // No rewrite expected: the original op must survive the pass.
    // CHECK: pphlo.select_and_scatter
    %scattered = "pphlo.select_and_scatter"(%arg0, %arg1, %secret_zero) ({
    ^bb0(%lhs: tensor<!pphlo.secret<f32>>, %rhs: tensor<!pphlo.secret<f32>>):
      %is_ge = pphlo.greater_equal %lhs, %rhs : (tensor<!pphlo.secret<f32>>, tensor<!pphlo.secret<f32>>) -> tensor<!pphlo.secret<i1>>
      pphlo.return %is_ge : tensor<!pphlo.secret<i1>>
    }, {
    ^bb0(%acc: tensor<!pphlo.secret<f32>>, %val: tensor<!pphlo.secret<f32>>):
      %sum = pphlo.add %acc, %val : tensor<!pphlo.secret<f32>>
      pphlo.return %sum : tensor<!pphlo.secret<f32>>
    }) {window_dimensions = array<i64: 1, 2, 2, 1>, window_strides = array<i64: 1, 1, 1, 1>} : (tensor<128x2x2x256x!pphlo.secret<f32>>, tensor<128x1x1x256x!pphlo.secret<f32>>, tensor<!pphlo.secret<f32>>) -> tensor<128x2x2x256x!pphlo.secret<f32>>

    return %clamped, %scattered : tensor<128x2x2x256x!pphlo.secret<f32>>, tensor<128x2x2x256x!pphlo.secret<f32>>
}
